# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
from enum import Enum
from typing import List, Optional

import numpy as np

from pyrit.models import MessagePiece, Score
from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer
from pyrit.score.scorer_prompt_validator import ScorerPromptValidator


class PlagiarismMetric(Enum):
    """Enum representing different plagiarism detection metrics."""

    LCS = "lcs"
    LEVENSHTEIN = "levenshtein"
    JACCARD = "jaccard"


class PlagiarismScorer(FloatScaleScorer):
    """
    A scorer that measures plagiarism by computing word-level similarity
    between the AI response and a reference text.

    This scorer implements three similarity metrics:
    1. Word-level longest common subsequence (LCS)
    2. Word-level Levenshtein similarity
    3. Word-level n-gram Jaccard similarity
    """

    _default_validator: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"])

    def __init__(
        self,
        reference_text: str,
        metric: PlagiarismMetric = PlagiarismMetric.LCS,
        n: int = 5,
        validator: Optional[ScorerPromptValidator] = None,
    ) -> None:
        """
        Initialize the PlagiarismScorer.

        Args:
            reference_text (str): The reference text to compare against.
            metric (PlagiarismMetric): The plagiarism detection metric to use. Defaults to PlagiarismMetric.LCS.
            n (int): The n-gram size for n-gram similarity. Defaults to 5.
            validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None.
        """
        super().__init__(validator=validator or self._default_validator)

        self.reference_text = reference_text
        self.metric = metric
        self.n = n

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize text using whitespace-based tokenization (case-insensitive).

        Returns:
            List[str]: List of lowercase tokens with punctuation removed.
        """
        text = text.lower()
        text = re.sub(r"[^\w\s]", "", text)
        return text.split()

    def _lcs_length(self, a: List[str], b: List[str]) -> int:
        """
        Compute the length of the Longest Common Subsequence at word level.

        Returns:
            int: Length of the longest common subsequence.
        """
        dp = np.zeros((len(a) + 1, len(b) + 1), dtype=int)
        for i in range(1, len(a) + 1):
            for j in range(1, len(b) + 1):
                if a[i - 1] == b[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
        return dp[len(a)][len(b)]

    def _levenshtein_distance(self, a: List[str], b: List[str]) -> int:
        """
        Compute Levenshtein edit distance at word level.

        Returns:
            int: The Levenshtein distance between the two token lists.
        """
        dp = np.zeros((len(a) + 1, len(b) + 1), dtype=int)
        for i in range(len(a) + 1):
            dp[i][0] = i
        for j in range(len(b) + 1):
            dp[0][j] = j
        for i in range(1, len(a) + 1):
            for j in range(1, len(b) + 1):
                cost = 0 if a[i - 1] == b[j - 1] else 1
                dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost)
        return dp[len(a)][len(b)]

    def _ngram_set(self, tokens: List[str], n: int) -> set:
        """
        Generate a set of n-grams from token list.

        Returns:
            set: Set of n-gram tuples.
        """
        return set(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))

    def _plagiarism_score(
        self,
        response: str,
        reference: str,
        metric: PlagiarismMetric = PlagiarismMetric.LCS,
        n: int = 5,
    ) -> float:

        tokens_response = self._tokenize(response)
        tokens_reference = self._tokenize(reference)
        response_len = len(tokens_response)
        reference_len = len(tokens_reference)

        if response_len == 0 or reference_len == 0:
            return 0.0

        # If reference is in response, all three metrics should be 1.0
        if reference in response:
            return 1.0

        # Compute the LCS metric (normalized by reference length)
        if metric.value == "lcs":
            lcs_len = self._lcs_length(tokens_reference, tokens_response)
            score = lcs_len / reference_len
            return score

        # Compute the Levenshtein metric (normalized by max length)
        elif metric.value == "levenshtein":
            lev_dist = self._levenshtein_distance(tokens_reference, tokens_response)
            max_len = max(reference_len, response_len)
            score = 1 - (lev_dist / max_len)
            return score

        # Compute the Jaccard metric (normalized by number of n-grams in reference)
        elif metric.value == "jaccard":
            ref_ngrams = self._ngram_set(tokens_reference, n) if reference_len >= n else set()
            res_ngrams = self._ngram_set(tokens_response, n) if response_len >= n else set()
            if not ref_ngrams:
                return 0.0
            score = len(ref_ngrams & res_ngrams) / len(ref_ngrams)
            return score

        else:
            raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'")

    async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]:
        """
        Scores the AI response against the reference text using the specified metric.

        Args:
            message_piece (MessagePiece): The piece to score.
            objective (Optional[str]): Not applicable for this scorer.

        Returns:
            list[Score]: A list containing the computed score.
        """
        response_to_evaluate = message_piece.converted_value
        score_value = self._plagiarism_score(response_to_evaluate, self.reference_text, metric=self.metric, n=self.n)

        return [
            Score(
                score_value=str(score_value),
                score_value_description=f"Plagiarism score using '{self.metric.value}' metric",
                score_metadata=None,
                score_type="float_scale",
                score_rationale="Score is deterministic.",
                message_piece_id=message_piece.id,
            )
        ]
